from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from ChatGLM_new import tongyi_llm
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import chain

def length_function(text):
    """Return ``len(text)`` — the character count when *text* is a string."""
    character_count = len(text)
    return character_count


def _multiple_length_function(text1, text2):
    return len(text1) * len(text2)


def multiple_length_function(_dict):
    """Single-argument adapter: unpack ``text1``/``text2`` from *_dict* and
    return the product of their lengths.

    A one-argument callable is what ``RunnableLambda`` wraps in this file, so
    the two-argument helper is exposed through this dict-taking wrapper.
    Raises ``KeyError`` if either key is missing.
    """
    text1, text2 = _dict["text1"], _dict["text2"]
    return _multiple_length_function(text1, text2)


# Demo: feed values derived from the input (string lengths) into a prompt.
# `tongyi_llm` is imported from the local ChatGLM_new module; its concrete
# model type is not visible here.
model = tongyi_llm

# Prompt with two input variables, {a} and {b}.
prompt = ChatPromptTemplate.from_template("what is {a} + {b}")

# Simplest pipeline: formatted prompt -> model.
chain1 = prompt | model

# The leading dict is turned by LangChain into a step that runs every entry on
# the same input dict and collects the results under the same keys. Those keys
# then fill the prompt variables:
#   a = len(input["foo"])
#   b = len(input["foo"]) * len(input["bar"])
chain2 = (
    {
        "a": itemgetter("foo") | RunnableLambda(length_function),
        "b": {"text1": itemgetter("foo"), "text2": itemgetter("bar")} | RunnableLambda(multiple_length_function),
    }
    | prompt
    | model
)

# NOTE(review): this runs at import time and makes a real model call. The other
# demos in this file are commented out instead. For this input, a = 3 and b = 9.
print(chain2.invoke({"foo": "bar", "bar": "gah"}))


##########################################@chain##########################################
# prompt1 asks the model for a joke about {topic}.
# The topic must sit inside the "关于 … 的" ("about …") frame. The previous
# template left the frame empty ("关于的笑话") and appended the topic after it,
# which is ungrammatical.
prompt1 = ChatPromptTemplate.from_template("告诉我一个关于 {topic} 的笑话")
# prompt2 asks what the subject of the given {joke} is.
prompt2 = ChatPromptTemplate.from_template("这个笑话的主题是什么: {joke}")
@chain
def custom_chain(text):
    """Ask the model for a joke about *text*, then ask it for that joke's subject.

    Makes two model calls through ``tongyi_llm``. The formatted first prompt is
    printed for demonstration. Returns the second reply as a plain string.
    """
    joke_prompt = prompt1.invoke({"topic": text})
    print(joke_prompt)
    raw_joke = tongyi_llm.invoke(joke_prompt)
    joke_text = StrOutputParser().invoke(raw_joke)
    # Named distinctly so it does not shadow the module-level `chain2`.
    subject_chain = prompt2 | tongyi_llm | StrOutputParser()
    return subject_chain.invoke({"joke": joke_text})
#print(custom_chain.invoke("bears"))

########################################## Automatic coercion in chains ######################
# Rebinds module-level `prompt` (and `model`, to the same object as before).
# Pipelines built earlier keep the prompt object they already captured.
prompt = ChatPromptTemplate.from_template("给我讲一个故事关于 {topic}")  # "Tell me a story about {topic}"
model = tongyi_llm
# The trailing plain lambda is coerced into a runnable by `|`, so no explicit
# RunnableLambda wrapper is needed. It keeps the first 500 characters of the reply.
# Assumes the model output exposes a `.content` string (e.g. a chat message);
# TODO confirm for tongyi_llm.
chain_with_coerced_function = prompt | model | (lambda x: x.content[:500])
#print(chain_with_coerced_function.invoke({"topic": "熊"}))


import json
from langchain_core.runnables import RunnableConfig

def parse_or_fix(text: str, config: RunnableConfig):
    """Parse *text* as JSON, asking the model to repair it when parsing fails.

    At most three repair round-trips are made. Every repaired text returned by
    the model is parsed again, including the result of the final repair.

    Args:
        text: Candidate JSON string.
        config: Runnable config passed to the repair chain, so tags/callbacks
            supplied when this function is invoked via ``RunnableLambda``
            propagate to the nested model calls.

    Returns:
        The decoded JSON value, or the string ``"Failed to parse"`` if the text
        is still invalid after the last repair attempt.
    """
    max_fixes = 3
    fixing_chain = (
        ChatPromptTemplate.from_template(
            "Fix the following text:\n\n```text\n{input}\n```\nError: {error}"
            " Don't narrate, just respond with the fixed data."
        )
        | model
        | StrOutputParser()
    )
    # max_fixes + 1 parse attempts: the original text plus one per repair.
    # Without the extra attempt, the last repair's output was generated (a paid
    # model call) and then discarded without being checked.
    for attempt in range(max_fixes + 1):
        try:
            return json.loads(text)
        except json.JSONDecodeError as e:
            if attempt == max_fixes:
                break
            text = fixing_chain.invoke({"input": text, "error": e}, config)
    return "Failed to parse"


from langchain_community.callbacks import get_openai_callback

# with get_openai_callback() as cb:
#     output = RunnableLambda(parse_or_fix).invoke(
#         "{foo: bar}", {"tags": ["my-tag"], "callbacks": [cb]}
#     )
#     print(output)
#     print(cb)